Toxic Comment Filter

BiLSTM model for multi label classification
code
Deep Learning
Python, R
Author

Simone Brazzi

Published

August 12, 2024

Introduction

  • Build a model able to filter user comments according to how harmful their language is.
  • Preprocess the text by removing the tokens that make no significant semantic contribution.
  • Transform the text corpus into sequences.
  • Build a Deep Learning model with recurrent layers for a multilabel classification task.

At prediction time, the model must return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). This way, a harmless comment is classified by a vector of all zeros, [0,0,0,0,0,0]. Conversely, a harmful comment has at least one 1 among the 6 labels.
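
To make the expected output concrete, here is a minimal sketch of the conversion from per-label probabilities to the 0/1 vector. The probabilities and the 0.5 cut-off are made-up values for illustration only; the threshold actually used is tuned later in the post.

```python
# Label order taken from the dataset description above.
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_binary_vector(probabilities, threshold=0.5):
    """Turn per-label probabilities into a 0/1 vector (threshold is an assumed value)."""
    return [1 if p >= threshold else 0 for p in probabilities]

# Hypothetical model outputs for two comments, one value per label.
clean_comment = [0.02, 0.00, 0.01, 0.00, 0.01, 0.00]
harmful_comment = [0.91, 0.12, 0.77, 0.03, 0.64, 0.05]

print(to_binary_vector(clean_comment))    # [0, 0, 0, 0, 0, 0]
print(to_binary_vector(harmful_comment))  # [1, 0, 1, 0, 1, 0]

# Names of the labels that were switched on for the harmful comment.
flagged = [name for name, bit in zip(LABELS, to_binary_vector(harmful_comment)) if bit]
print(flagged)  # ['toxic', 'obscene', 'insult']
```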

Setup

Leveraging Quarto and RStudio, I will set up an R and Python environment.

Import R libraries

These libraries are used both for rendering the document and for data analysis, because I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.

Code
library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)

reticulate::use_virtualenv("r-tf")

Import Python packages

Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp

from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score


from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score

Create a Config class to store all the useful parameters for the model and for the project.

Class Config

I created a class with all the basic configuration of the model, to improve the readability.

Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911 # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571 # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model =  self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
            
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1", # "val_recall",
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1", # "val_recall",
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint

    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):

      # instantiate KFold
      kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
      threshold_scores = []

      for threshold in thresholds:

        cv_scores = []
        for train_index, val_index in kf.split(ytrue):

          ytrue_val = ytrue[val_index]
          yproba_val = yproba[val_index]

          ypred_val = (yproba_val >= threshold).astype(int)
          score = metric(ytrue_val, ypred_val, average="macro")
          cv_scores.append(score)

        mean_score = np.mean(cv_scores)
        threshold_scores.append((threshold, mean_score))

      # Find the threshold with the highest mean score (once, after all thresholds are scored)
      best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
      return best_threshold, best_score

config = Config()

Data

The dataset is downloaded from the URL with tf.keras.utils.get_file. N.B. For reproducibility, I also saved a local copy of the dataset, because the link was unavailable for a while.

Code
# df = pd.read_csv(config.path)
file = tf.keras.utils.get_file("Filter_Toxic_Comments_dataset.csv", config.url)
df = pd.read_csv(file)
Code
library(reticulate)

py$df %>%
  tibble() %>% 
  head(5)
Table 1: First 5 elements
# A tibble: 5 × 8
  comment_text            toxic severe_toxic obscene threat insult identity_hate
  <chr>                   <dbl>        <dbl>   <dbl>  <dbl>  <dbl>         <dbl>
1 "Explanation\nWhy the …     0            0       0      0      0             0
2 "D'aww! He matches thi…     0            0       0      0      0             0
3 "Hey man, I'm really n…     0            0       0      0      0             0
4 "\"\nMore\nI can't mak…     0            0       0      0      0             0
5 "You, sir, are my hero…     0            0       0      0      0             0
# ℹ 1 more variable: sum_injurious <dbl>

Let's create a clean variable for EDA purposes: I want to see how many observations are clean compared with the other labels.

Code
df.loc[df.sum_injurious == 0, "clean"] = 1
df.loc[df.sum_injurious != 0, "clean"] = 0

EDA

First a check on the dataset to find possible missing values and imbalances.

Frequency

Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels

df_r_grouped <- df_r %>% 
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>% 
  group_by(label) %>%
  summarise(count = sum(value)) %>% 
  mutate(freq = round(count / sum(count), 4))

df_r_grouped
Table 2: Absolute and relative labels frequency
# A tibble: 7 × 3
  label          count   freq
  <chr>          <dbl>  <dbl>
1 clean         143346 0.803 
2 identity_hate   1405 0.0079
3 insult          7877 0.0441
4 obscene         8449 0.0473
5 severe_toxic    1595 0.0089
6 threat           478 0.0027
7 toxic          15294 0.0857

Barchart

Code
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu")
ggplotly(barchart)
Figure 1: Imbalance in the dataset with clean variable

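The plot makes clear how imbalanced the dataset is. It could therefore be useful to compute class weights and pass them as an argument during training.

Most of the texts are clean: clean accounts for 0.8033 of all label occurrences in Table 2, and the six toxic labels together for the remaining 0.1967. Since freq is computed over the sum of the label counts and a comment can carry several labels, the share of clean comments is higher still: 143346 of the 159571 observations, about 0.898.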

Sequence length definition

To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.

One of its arguments is output_sequence_length: to choose it, it is useful to analyze the length of our texts. To simulate what the layer will do, we remove the punctuation and the new lines from the comments.

Summary

Code
library(reticulate)
df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  pull(text_length) %>% 
  summary() %>% 
  as.list() %>% 
  as_tibble()
Table 3: Summary of text length
# A tibble: 1 × 6
   Min. `1st Qu.` Median  Mean `3rd Qu.`  Max.
  <dbl>     <dbl>  <dbl> <dbl>     <dbl> <dbl>
1     4        91    196  378.       419  5000

Boxplot

Code
library(reticulate)
boxplot <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  # pull(text_length) %>% 
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
Figure 2: Text length boxplot

Histogram

Code
library(reticulate)
df_ <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )

Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)

histogram <- df_ %>% 
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Figure 3: Text length histogram with boxplot upper fence

Considering the analysis above, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot, shown as the dashed red vertical line in the last plot. This truncates only the outliers, which are a small part of our dataset.
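
The value 911 can be re-derived by hand from the quartiles in Table 3. A minimal sketch, with the two quartiles typed in from the table rather than recomputed from the data:

```python
# Quartiles of the cleaned text length, read from Table 3.
q1, q3 = 91, 419

iqr = q3 - q1                      # interquartile range
upper_fence = int(q3 + 1.5 * iqr)  # same rule as the R code: Q3 + 1.5 * IQR

print(iqr, upper_fence)  # 328 911
```

Note that text_length is measured in characters, since str_count() is called without a pattern, whereas output_sequence_length counts tokens. A comment never has more tokens than characters, so 911 is a conservative choice.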

Dataset

Now we can split the dataset in 3: train, test and validation sets. Since sklearn has no function that splits into these 3 sets directly, we can do the following:

  • split between a train and a temporary set with a 0.3 split.
  • split the temporary set into 2 equally sized test and val sets.

Code
x = df[config.features].values
y = df[config.labels].values

xtrain, xtemp, ytrain, ytemp = train_test_split(
  x,
  y,
  test_size=config.temp_split, # .3
  random_state=config.random_state
  )
xtest, xval, ytest, yval = train_test_split(
  xtemp,
  ytemp,
  test_size=config.test_split, # .5
  random_state=config.random_state
  )
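
As a sanity check on the two-step split, the expected set sizes can be worked out with plain arithmetic. This sketch assumes the held-out side of each split is rounded up and the remainder goes to the other side, which is my understanding of how train_test_split handles fractional sizes; the totals come from the Config class.

```python
import math

total_samples = 159571  # config.total_samples
temp_split = 0.3        # config.temp_split
test_split = 0.5        # config.test_split

# Assumption: the held-out part is rounded up, the rest stays on the other side.
n_temp = math.ceil(total_samples * temp_split)
n_train = total_samples - n_temp

# Second split: the temporary set is halved into test and validation.
n_val = math.ceil(n_temp * test_split)
n_test = n_temp - n_val

print(n_train, n_test, n_val)  # 111699 23936 23936
```

These numbers agree with config.train_samples and config.val_samples.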

  • xtrain shape: py$xtrain.shape
  • ytrain shape: py$ytrain.shape
  • xtest shape: py$xtest.shape
  • ytest shape: py$ytest.shape
  • xval shape: py$xval.shape
  • yval shape: py$yval.shape

The datasets are created using tf.data.Dataset, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which builds a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.

Code
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Code
print(
  f"train_ds cardinality: {train_ds.cardinality()}\n",
  f"val_ds cardinality: {val_ds.cardinality()}\n",
  f"test_ds cardinality: {test_ds.cardinality()}\n"
  )
train_ds cardinality: 3491
 val_ds cardinality: 748
 test_ds cardinality: 748
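
The cardinalities above are just the number of batches per split. A minimal sketch of the arithmetic, assuming the test set has the same size as the validation set (as implied by the 0.5 split) and that the last partial batch is kept:

```python
import math

batch_size = 32  # config.batch_size
# Train and val sizes come from Config; the test size is an assumption (same as val).
split_sizes = {"train": 111699, "val": 23936, "test": 23936}

# Without dropping the remainder, a final partial batch still counts as one batch.
cardinality = {name: math.ceil(n / batch_size) for name, n in split_sizes.items()}

print(cardinality)  # {'train': 3491, 'val': 748, 'test': 748}
```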

Check the first element of the dataset to be sure that the preprocessing is done correctly.

Code
train_ds.as_numpy_iterator().next()
(array([b'"The so called reliable sources are opinion based and are no more importent then your or mine. And you are not getting the point of what Mark Jansen is saying. He dose not have a problem with it because he can\'t change there minds. \n\n""On top of that you are ""Epica uses a ""trademark of many symphonic and gothic metal bands"" in contrasting ""two extremes, death grunts and brutality on one side, airy female melodiousness on the other.""\n\nHere\'s The Project Hate http://www.youtube.com/watch?v=8mJk6hD9czY\nShocking the use brutality death grunts and female melodiousness voacals to. Yet they are Death meatl.\n\n""attraction ultimately hinges on exploring the sonic contrasts of light and dark; the punishing intensity of those elephantine guitar riffs and hyperactive drumming cast against the soaring, layered sweetness of the orchestrated strings and keyboards.""\n\nYeah them and evet other Meatl band that dose that to Ram-Zet http://www.youtube.com/watch?v=gTTUP43nq9k\n\n""The group is also known to employ human choirs and orchestras with additional embellishments such as spoken word recitals and lyrics in Latin and Arabic""\n\nI can name so many metal bands that do that. That is it not something that Gothic Metal just uses.\n\nAnd again I will say knowing your subject makes it much more easy to edit a page. But going by ever thing you say I should go Edit the Kenny G page to what ever the so called reliable sources. Reviews are opinions ever one has them and reviews from so called reliable sources mean jack.\n\n"',
       b"I'll be spending the 48 hours whittling down my watch list, as that's what got me into trouble the last time.",
       b'":::I had a response here... I don\'t know what happened to it, if i didn\'t save it or what.  Frustrating! I might find it eventually.  The point of it, basically, was: Personally, I would not be the least bit offended if someone said that to me.  Up to the ""I am also chastised by..."" part, that is.  There, I sympathize, as noone should like to nor deserves to be viewed in black and white. talk 22:09, 2005 Mar 20 (UTC)\n\n"',
       b'There were ample sources available simply by googling for DNA melting Temperature. I was only bothered to bring three references but I guess a theory can only be proven wrong, can never be proven right for all situations but then again you should know that. By the way, if you do agree with this information, can you please tell me how to incorporate into the article? Thank you.',
       b'Reply: Vandalism \n\nOk.',
       b"October 2011 \n\nPlease stop adding content with unreliable sources at 2012 AFC President's Cup.  The revisions that you have made also do not apply to the 2012 AFC President's Cup.",
       b'" and misinterpreting it deliberately (or you\'re not, which then, well, god have mercy)\nWe\'re at the point where, you\'re simply not taking in any points i\'ve raised and dispel unfavorable ones as ""invalid"", and this\'ll attrite till it\'s pearly gates and trumpet sounds. i\'m disinclined to concede as well so, it\'s probably a good time for a third opinion"',
       b'MOTHER FUCKING INBREED PIECES OF SHIT',
       b'for potexionn, block dis page....MWAAAAAAAAAAAAAAAAAAAAAHAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHAAAAAAAAAAAAAAAAAAAAAAA!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!',
       b'Separate article for Moynihan Station? \n\nI think the time has come to give moynihan Station its own entry.  There is enough information, and released sketches for this.',
       b'"\n\n ""I do not have a personal connection, either by family, descent, or religion, with the people and places I write about in Wikipedia."" \n\nI believe you; thousands wouldn\'t. ~   "',
       b'"\n\n Alerting you to a newly opened wiki alert \n\nI have opened a wiki alert here. |\xc2\xa0\xe2\x96\xbaChat\xc2\xa0 "',
       b"Aklauncher, I feel sad because Wikipedia doesn't have an article about The Beta Adrenergic Theory thats been cited much more than the Hygiene Hypothasis.  I would be most gratful if you would write about The Beta Adrenergic Theory based partly on the sub-article under Asthma.\nMy dad Andor Szentivanyi discovered it and it's been cited by such famous people as Craig Venter {The Human Genome Project} and Charles Reed {Chief of Allergy, The Mayo Clinic}.",
       b'terminology that can be construed as',
       b'I have added to the discussion to make the relationship more direct.',
       b'I HATE YOU! \n\n I HATE YOU',
       b"I've seen that threat at the start of DVDs, with no distinction made between uploads and downloads.  It may very well be that they only prosecute uploads, but I'm also opposed to having overly strict laws and leaving it up to the authorities to decide who to prosecute, as this gives them too much power over us.  I think I will add this to the list.",
       b"Well, then I do know that Frye said Azeri Turks were the founders of Safavid dynasty, but I do not list it exactly as said. And the point of poetry, a person can be fascinated by Theodore Dreiser and name his daughter Carrie, it does not mean that person claims descendance from Americans. So the argument that he claimed or believed to be descendant of Royal Persians does not have ground, especially given the fact that Ismail claimed being a descendant of Ali ibn Abu Talib. Similarly, we don't claim Ismail as a pure Turk, based on the fact that he wrote primarily in Turkish.",
       b"sorry \n\nFor stuffing the battery page - mt internet went down during a save I have fixed it at an internet cafe  chad\n\nNo worries. Another semi-common event is that certain browsers will only edit 32KB of text (which is why the Wiki software warns you whenever a page gets longer). If you try to edit the article with these browsers, they have a nasty habit of cutting off the article at th3 32,769th character. (That's what I figured had happened to you.)",
       b'", 26 August 2010 (UTC)\nThanks, also could you delete Template:NASCAR Winston Cup driver it was repplaced by Template:Infobox NASCAR driver. Thanks. 1996  22:22"',
       b'Rebuild \nI enjoyed this page when I first saw it about a year ago.  I have some of the public domain pictures that were originally posted here but have now been deleted.  It has been sometime since the originator has posted.  I just recently became a member of Wikipedia and am still learning how to post things.  I will try to replace some of the pictures that have been deleted.  I also think it is a good idea for anyone else to add their own examples of Korean pieces if they feel it will add to the educational value of this site.  I have some of my own that I plan to add. I believe that it is in the best interest of Wikipedia to build up and not tear down.\n 5 Oct 2008',
       b'"::: For the remainder of my adoption and if I were to be unblock, I\'d be more than happy to remove my message from his talk. \'\'\'\'\'\' (Chat \xe2\x80\xa2 Count) \n\n"',
       b"What's the objection to starting a sentence with a number?",
       b'"\n\nThat fellow Editor uses few aliases like rpo.castro when he dislikes others contributuons. Also uses confusion by inducing ""changing of tournament sceme"" every once in a while ... very often. And thas all he contributes. After he see he\'s nazi fires and deletions are wrong he uses parts of it in his ""later editions"". Funny editor ... But hey he\'s good for wikis profit isnt it if you resell content."',
       b'"\n\n Image tagging \n\nHi, Pece. Thanks for the kind note! Sorry I haven\'t been able to answer your messages. I\'ve had a lot of work and studying to complete, which is why I\'m stressed. As for the images, you\'ll find all the tags at Wikipedia:Image copyright tags. The information needed should be on the websites and on the aforementioned Wikipedia page, but if you need any help, just leave a note at my talk page. Regards, 123 (talk) "',
       b'Evanescence new album took 3 years to make \n\ni think amy took so long on writing her songs because she was dealing with her breakup. i dont think it had anything to do with band members leaving.???',
       b'Thanks for informing violation policies. I will take care of it.',
       b'Please be informed that , like fellow Eli Soriano follower Dar book, has been permanently blocked from editing Wikipedia due to abuse of account privileges (i.e. using a sock puppet to whitewash a duly-sourced article). \n\nWikipedia is, above all, an encyclopedia, and its articles should only use reliable, third-party, published sources. Editors like  and  seek to undermine the integrity of those reliable sources by using all sorts of trickery and deception to inject unsourced promotional bits and pieces in an obvious attempt to sanitize the article which, I believe, is reflective of the ways and character of their leader and organization as a whole. \n\nPlease stay vigilant as incarnations of Petersantos a.k.a. Journeyist a.k.a. Patrick Lim a.k.a. OQtE5AokaDv4 etc. are destined to resurface from the inner recesses of what seems to be the most notorious pseudo-Christian movement in the Philippines. -',
       b'Sequel\nwill there be a sequel',
       b'"\nThere are other valid interpretations (what\'s more interpretations that are backed up by sources). Bush didn\'t want to spread panic , and if he had stormed out of the room the footage would have been replayed endlessly on television and it wouldn\'t have done much for national morale.  "',
       b"Dr. Cirt Joined (3) \n\nI Promise, Unblock it, and then i won't vandalise the page anymore.",
       b'"\nLooie...point taken.  But this is my anti-woo defense mechanism always pops up, even where we don\'t really have to know the mechanism (hell, most of treatment of depression is magic).  OK, omega 3 is far above magical water cures, I admit.   Talk\xe2\x80\xa2 Contributions "'],
      dtype=object), array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 1, 0],
       [1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]))

We also check the shape. We expect a feature of shape (batch, ) and a target of shape (batch, number of labels).

Code
print(
  f"text train shape: {train_ds.as_numpy_iterator().next()[0].shape}\n",
  f" text train type: {train_ds.as_numpy_iterator().next()[0].dtype}\n",
  f"label train shape: {train_ds.as_numpy_iterator().next()[1].shape}\n",
  f"label train type: {train_ds.as_numpy_iterator().next()[1].dtype}\n"
  )
text train shape: (32,)
  text train type: object
 label train shape: (32, 6)
 label train type: int64

Preprocessing

Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:

  1. Standardize each example (usually lowercasing + punctuation stripping).
  2. Split each example into substrings (usually words).
  3. Recombine substrings into tokens (usually ngrams).
  4. Index tokens (associate a unique int value with each token).
  5. Transform each example using this index, either into a vector of ints or a dense float vector.

For more reference, see the documentation at the following link.
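
To make the steps tangible, here is a toy re-implementation in plain Python. It is only a sketch of the idea, not the Keras layer: the function names are hypothetical, the n-gram step is skipped, and I assume index 0 is used for padding and index 1 for out-of-vocabulary tokens.

```python
import string
from collections import Counter

def standardize(text):
    """Lowercase and strip ASCII punctuation."""
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocab(texts, max_tokens):
    """Map the most frequent tokens to ints; 0 (padding) and 1 (OOV) are reserved."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    most_common = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return {tok: i + 2 for i, tok in enumerate(most_common)}

def vectorize(text, vocab, output_sequence_length):
    """Index the tokens, then truncate or right-pad to a fixed length."""
    ids = [vocab.get(tok, 1) for tok in standardize(text).split()]
    ids = ids[:output_sequence_length]
    return ids + [0] * (output_sequence_length - len(ids))

corpus = ["You, sir, are my hero!", "Hey man, you are great."]
vocab = build_vocab(corpus, max_tokens=10)
encoded = vectorize("You are my villain", vocab, output_sequence_length=6)
print(encoded)
```

The unseen word "villain" maps to 1 and the two trailing positions are padded with 0, which is what lets an Embedding layer with mask_zero=True ignore them.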

Code
text_vectorization = TextVectorization(
  max_tokens=config.max_tokens,
  standardize="lower_and_strip_punctuation",
  split="whitespace",
  output_mode="int",
  output_sequence_length=config.output_sequence_length,
  pad_to_max_tokens=True
  )

# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)

This layer is set to:

  • max_tokens: 20000, a common value for text classification. It is the maximum size of the vocabulary for this layer.
  • output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
  • output_mode: "int", which outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
  • standardize: "lower_and_strip_punctuation".
  • split: on whitespace.

To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, we can map the layer over the features of each dataset.

Code
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

Model

Definition

Define the model using the Functional API.

Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()

Callbacks

Finally, the model has been trained using 2 callbacks:

  • Early Stopping, to avoid wasting Kaggle GPU time.
  • Model Checkpoint, to retrieve the best model training information.

Code
# callbacks
my_es = config.get_early_stopping()
my_mc = config.get_model_checkpoint(filepath="/checkpoint.keras")
callbacks = [my_es, my_mc]
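
To see how patience, min_delta and start_from_epoch interact, here is a simplified sketch of the early stopping rule in "max" mode. It is my own reduction of the logic, not the Keras source, and the val_f1 values are invented for illustration.

```python
def early_stopping_epoch(scores, min_delta, patience, start_from_epoch=0):
    """Return the 0-based epoch at which training would stop, or None.

    Simplified 'max' mode rule: an epoch only counts as an improvement
    if score - min_delta is strictly greater than the best score so far.
    """
    best = float("-inf")
    wait = 0
    for epoch, score in enumerate(scores):
        if epoch < start_from_epoch:
            continue  # warm-up epochs are not monitored
        if score - min_delta > best:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation F1 per epoch, slowly improving.
val_f1 = [0.10, 0.20, 0.30, 0.35, 0.38, 0.40, 0.41, 0.42]

strict = early_stopping_epoch(val_f1, min_delta=0.2, patience=3, start_from_epoch=3)
loose = early_stopping_epoch(val_f1, min_delta=0.0, patience=3, start_from_epoch=3)
print(strict, loose)  # 6 None
```

Under this reading, min_delta=0.2 is a demanding bar for a metric bounded in [0, 1]: small but steady gains do not reset the patience counter, so training can stop while the score is still improving.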

Final preparation before fit

Considering the dataset is imbalanced, to increase the performance we need to calculate the class weight. This will be passed during the training of the model.

Code
lab = pd.DataFrame(columns=config.labels, data=ytrain)
r = lab.sum() / len(ytrain)
class_weight = dict(zip(range(len(config.labels)), r))
df_class_weight = pd.DataFrame.from_dict(
  data=class_weight,
  orient='index',
  columns=['class_weight']
  )
df_class_weight.index = config.labels
Code
library(reticulate)
py$df_class_weight
Table 4: Class weight
              class_weight
toxic          0.095900590
severe_toxic   0.009928468
obscene        0.052757858
threat         0.003061800
insult         0.049132042
identity_hate  0.008710911
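
The values in Table 4 are the label frequencies themselves, so rarer labels get smaller numbers. For comparison only, here is a sketch of a common alternative that is not what this post uses: inverse-frequency weights, rescaled so that the most frequent label has weight 1. The frequencies are copied from Table 4; the weighting scheme is my own assumption.

```python
# Positive-label frequencies in the training split, copied from Table 4.
label_freq = {
    "toxic": 0.095900590,
    "severe_toxic": 0.009928468,
    "obscene": 0.052757858,
    "threat": 0.003061800,
    "insult": 0.049132042,
    "identity_hate": 0.008710911,
}

# Hypothetical alternative: weight each label by the inverse of its frequency,
# rescaled so the most common label gets weight 1.
max_freq = max(label_freq.values())
inverse_weight = {label: max_freq / f for label, f in label_freq.items()}

# Print from the most common label (smallest weight) to the rarest (largest weight).
for label, w in sorted(inverse_weight.items(), key=lambda kv: kv[1]):
    print(f"{label:>13}: {w:6.2f}")
```

With this scheme the rarest label, threat, receives the largest weight, which is the opposite ordering of Table 4.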

It is also useful to define the steps per epoch for the train and validation datasets. Without this step, the fit did not consume the dataset entirely, which happened to me.

Code
steps_per_epoch = config.train_samples // config.batch_size
validation_steps = config.val_samples // config.batch_size

Fit

The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:

  • .repeat() makes the dataset repeat indefinitely, so the model can see all the data at every epoch.
  • epochs is set to 100.
  • validation_data has the same repeat.
  • callbacks are the ones defined before.
  • class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
  • steps_per_epoch and validation_steps are needed because of the use of repeat.
Code
history = lstm_model.fit(
  processed_train_ds.repeat(),
  epochs=config.epochs,
  validation_data=processed_val_ds.repeat(),
  callbacks=callbacks,
  class_weight=class_weight,
  steps_per_epoch=steps_per_epoch,
  validation_steps=validation_steps
  )

Now we can import the model and the history trained on Kaggle.

Code
model = load_model(filepath=config.model)
history = pd.read_excel(config.history)

Evaluate

Code
validation = model.evaluate(
  processed_val_ds.repeat(),
  steps=validation_steps, # 748
  verbose=0
  )
Code
tibble(
  metric = c("loss", "precision", "recall", "auc", "f1_score"),
  value = py$validation
  )
Table 5: Model validation metric
# A tibble: 5 × 2
  metric     value
  <chr>      <dbl>
1 loss      0.0542
2 precision 0.789 
3 recall    0.671 
4 auc       0.957 
5 f1_score  0.0293

Predict

For the prediction, the model does not need to repeat the dataset, because it has already been trained on all of the train data. Now it has just to consume the new data to make the prediction.

Code
predictions = model.predict(processed_test_ds, verbose=0)

Confusion Matrix

The best way to assess the performance of a multi label classifier is a confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, which handles the fact that there can be multiple labels for one prediction.
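
As a small illustration of what a multi label confusion matrix contains, the sketch below counts one 2x2 matrix per label in plain Python. It follows the [[TN, FP], [FN, TP]] layout that, as far as I know, sklearn uses for each label; the function name and the toy data are hypothetical.

```python
def multilabel_confusion(y_true, y_pred):
    """Return one [[TN, FP], [FN, TP]] matrix per label column."""
    n_labels = len(y_true[0])
    matrices = []
    for j in range(n_labels):
        tn = fp = fn = tp = 0
        for true_row, pred_row in zip(y_true, y_pred):
            t, p = true_row[j], pred_row[j]
            if t == 1 and p == 1:
                tp += 1
            elif t == 0 and p == 1:
                fp += 1
            elif t == 1 and p == 0:
                fn += 1
            else:
                tn += 1
        matrices.append([[tn, fp], [fn, tp]])
    return matrices

# Toy example: 4 comments, 3 labels.
y_true = [[1, 0, 1], [0, 0, 0], [1, 1, 0], [0, 0, 1]]
y_pred = [[1, 0, 0], [0, 0, 0], [1, 0, 0], [1, 0, 1]]

matrices = multilabel_confusion(y_true, y_pred)
for j, m in enumerate(matrices):
    print(f"label {j}: {m}")
```

Each label is treated as its own binary problem, so a comment with several labels contributes to several matrices at once.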

Grid Search Cross Validation for best threshold

Grid Search CV is a technique for fine-tuning the hyperparameters of an ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique that splits the data into k consecutive folds. Each fold is used once as validation while the k - 1 remaining folds form the training set. See the documentation for more information.

The aim is to optimize recall, because the cost of missing a True Positive is greater than the cost of a False Positive. In this case, missing an injurious observation is worse than classifying a clean one as bad.

Confidence threshold and Precision-Recall trade off

Whilst the KFold GDCV technique is usefull to test multiple hyperparameter, it is important to understand the problem we are facing. A multi label deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.

  • The higher the threshold, the fewer classes the model predicts, increasing model confidence [higher Precision] and increasing missed classes [lower Recall].
  • The lower the threshold, the more classes the model predicts, decreasing model confidence [lower Precision] and decreasing missed classes [higher Recall].
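The trade-off can be made concrete with a small sketch for a single label. The probabilities below are invented for illustration and are not model outputs; precision_recall_at is a hypothetical helper.

```python
import numpy as np

# Illustrative scores for a single label: 1 = toxic, 0 = clean.
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
y_proba = np.array([0.9, 0.6, 0.3, 0.1, 0.7, 0.4, 0.2, 0.08, 0.03, 0.01])


def precision_recall_at(threshold):
    """Binarize the probabilities at `threshold` and return (precision, recall)."""
    y_pred = y_proba >= threshold
    tp = np.sum(y_pred & (y_true == 1))   # toxic comments correctly flagged
    fp = np.sum(y_pred & (y_true == 0))   # clean comments wrongly flagged
    fn = np.sum(~y_pred & (y_true == 1))  # toxic comments missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn)
    return float(precision), float(recall)


for threshold in (0.05, 0.5, 0.8):
    p, r = precision_recall_at(threshold)
    print(f"threshold={threshold:.2f} precision={p:.2f} recall={r:.2f}")
# threshold=0.05 precision=0.50 recall=1.00
# threshold=0.50 precision=0.67 recall=0.50
# threshold=0.80 precision=1.00 recall=0.25
```

On this toy data, lowering the threshold from 0.8 to 0.05 recovers every toxic comment, at the price of flagging as many clean comments as toxic ones.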

Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnosis. It is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to fail to predict cancer when the patient has the disease [False Negative].

I decided to train the model on the F1 score, to obtain a model balanced between precision and recall, and to leave it to the threshold selection to increase the recall performance.

Moreover, the model has been trained on the macro average F1 score, a single performance indicator obtained as the unweighted mean of the F1 scores of the individual classes, each of which combines that class's Precision and Recall.

It is useful for imbalanced classes because it weights each class equally, regardless of the number of samples in each class. This is set both in config.metrics and in find_optimal_threshold_cv.
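A small sketch shows why the averaging choice matters on imbalanced labels. The data is invented for illustration: one frequent, well predicted label and one rare, poorly predicted label. The average parameter of sklearn's f1_score selects the aggregation.

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy data: label 0 is frequent (8 positives), label 1 is rare (2 positives).
y_true = np.array([[1, 0]] * 8 + [[0, 1]] * 2)
y_pred = np.array([[1, 0]] * 7 + [[1, 1], [0, 1], [0, 0]])

per_class = f1_score(y_true, y_pred, average=None)  # one F1 per label
macro = f1_score(y_true, y_pred, average="macro")   # unweighted mean of per-label F1
micro = f1_score(y_true, y_pred, average="micro")   # F1 from pooled TP, FP and FN

print(f"per-class F1: {per_class}")
print(f"macro F1: {macro:.2f}")  # 0.75
print(f"micro F1: {micro:.2f}")  # 0.90
```

The rare label drags the macro average down to 0.75, while the micro average stays at 0.90 because the frequent label dominates the pooled counts. This is the sense in which the macro average is not influenced by class frequency.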

f1_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_f1, best_score_f1 = config.find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)

print(f"Optimal threshold: {optimal_threshold_f1}")
Optimal threshold: 0.15000000000000002
Code
print(f"Best score: {best_score_f1}")
Best score: 0.4788653077945807
Code

# Use the optimal threshold to make predictions
final_predictions_f1 = (y_pred_proba >= optimal_threshold_f1).astype(int)

Optimal threshold f1 score: 0.15. Best score: 0.4788653.

recall_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_recall, best_score_recall = config.find_optimal_threshold_cv(ytrue, y_pred_proba, recall_score)

# Use the optimal threshold to make predictions
final_predictions_recall = (y_pred_proba >= optimal_threshold_recall).astype(int)

Optimal threshold recall: 0.05. Best score: 0.8095814.

roc_auc_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_roc, best_score_roc = config.find_optimal_threshold_cv(ytrue, y_pred_proba, roc_auc_score)

print(f"Optimal threshold: {optimal_threshold_roc}")
Optimal threshold: 0.05
Code
print(f"Best score: {best_score_roc}")
Best score: 0.8809499649742268
Code

# Use the optimal threshold to make predictions
final_predictions_roc = (y_pred_proba >= optimal_threshold_roc).astype(int)

Optimal threshold roc: 0.05. Best score: 0.88095.

Confusion Matrix Plot

Code
# convert predicted probabilities to binary predictions using the recall-optimal threshold
ypred = predictions >= optimal_threshold_recall # .05
ypred = ypred.astype(int)

# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
Figure 4: Multi Label Confusion matrix

Classification Report

Code
cr = classification_report(
  ytrue,
  ypred,
  target_names=config.labels,
  digits=4,
  output_dict=True
  )
df_cr = pd.DataFrame.from_dict(cr).reset_index()
Code
library(reticulate)
df_cr <- py$df_cr %>% dplyr::rename(names = index)
cols <- df_cr %>% colnames()
df_cr %>% 
  pivot_longer(
    cols = -names,
    names_to = "metrics",
    values_to = "values"
  ) %>% 
  pivot_wider(
    names_from = names,
    values_from = values
  )
Table 6: Classification report
# A tibble: 10 × 5
   metrics       precision recall `f1-score` support
   <chr>             <dbl>  <dbl>      <dbl>   <dbl>
 1 toxic            0.552  0.890      0.682     2262
 2 severe_toxic     0.236  0.917      0.375      240
 3 obscene          0.550  0.936      0.692     1263
 4 threat           0.0366 0.493      0.0681      69
 5 insult           0.471  0.915      0.622     1170
 6 identity_hate    0.116  0.720      0.200      207
 7 micro avg        0.416  0.896      0.569     5211
 8 macro avg        0.327  0.812      0.440     5211
 9 weighted avg     0.495  0.896      0.629     5211
10 samples avg      0.0502 0.0848     0.0597    5211

Conclusions

The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label. Considering the low support for the threat label, its performance is not bad. See Table 2 and Figure 1: the threat label accounts for only 0.27 % of the observations. The model has been optimized for recall because the cost of failing to identify an injurious comment is higher than the cost of flagging a clean comment as injurious.

Possible improvements include increasing the number of observations, especially for the threat label. In general, there are too many clean comments. This could be mitigated by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.